from datasets.wavallin import _process_utterance as process_utterance
import glob
from hparams import hparams
from IPython.display import Audio, display, HTML
import librosa
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm
import shutil

# Figure resolution shared by all plotting helpers below.
DPI = 100

# Start from an empty `features` directory so stale files from an earlier run
# cannot be picked up. This replaces the IPython-only `!rm -rf features` /
# `!mkdir features` shell magics with portable Python.
shutil.rmtree('features', ignore_errors=True)
os.makedirs('features')

# Collect the generated waveforms of the latest checkpoint for both splits.
# glob returns files in arbitrary order, so sort for a reproducible run.
generated_examples_paths = []
for subdir in ['dev', 'eval']:
    generated_examples_paths.extend(sorted(glob.glob(
        os.path.join('generated', 'checkpoint_latest', subdir, '*_gen.wav'))))

# Run the project's utterance processing on every generated waveform.
# 'features' is presumably its output directory (TODO confirm the meaning of
# the remaining arguments against datasets.wavallin._process_utterance).
for path in tqdm(generated_examples_paths):
    process_utterance('features', 0, path, '')
100%|██████████| 10/10 [00:01<00:00, 7.68it/s]
def load_loss(path):
    """Read a loss log made of ``step,loss`` lines.

    Args:
        path: Text file with one comma-separated ``step,loss`` pair per line.
            Blank lines, including a trailing empty line, are ignored.

    Returns:
        Tuple ``(steps, losses)`` of NumPy arrays: the integer steps and the
        float losses, in file order.
    """
    steps = []
    losses = []
    # `with` closes the file even if a line fails to parse.
    with open(path) as f:
        for line in f:
            # strip() instead of line[:-1]: the last line may lack a newline,
            # and slicing would then drop a real digit of the loss.
            line = line.strip()
            if not line:
                continue
            step, loss = line.split(',')
            steps.append(int(step))
            losses.append(float(loss))
    return np.asarray(steps), np.asarray(losses)
def show_losses(steps, train_losses, test_losses):
    """Draw the train and test loss curves in a single figure.

    Args:
        steps: x coordinates used for both curves.
        train_losses: training loss at each entry of ``steps``.
        test_losses: test loss at each entry of ``steps``.
    """
    figure, axis = plt.subplots(dpi=DPI)
    figure.suptitle('Loss curves')
    curves = ((train_losses, 'Train'), (test_losses, 'Test'))
    for values, curve_name in curves:
        axis.plot(steps, values, label=curve_name)
    axis.set_xlabel('Batch')
    axis.set_ylabel('Loss')
    axis.legend()
    plt.show()
def config_mel_ax(ax):
    """Label an axis showing a mel spectrogram image in seconds (x) and Hz (y).

    Assumes the image was drawn with x in frame units and y normalised to
    [0, 1], e.g. ``imshow(..., extent=(0, n_frames, 0, 1))``.
    """
    def format_x(value, tick_number):
        # Frame position -> seconds: each frame advances hop_size samples.
        return '%.1f' % (value * hparams.hop_size / hparams.sample_rate)

    mel_freqs = librosa.mel_frequencies(n_mels=hparams.num_mels, fmin=hparams.fmin, fmax=hparams.fmax)
    last_bin = hparams.num_mels - 1

    def format_y(value, tick_number):
        # Map the normalised position in [0, 1] to the nearest mel bin.
        # Tick locators can return locations just outside the view limits, so
        # clamp the index: a negative index would silently wrap to the top
        # bins, and one past the end would raise IndexError.
        i = min(max(int(round(value * last_bin)), 0), last_bin)
        return '%.1f' % mel_freqs[i]

    ax.xaxis.set_major_formatter(plt.FuncFormatter(format_x))
    ax.set_xlabel('Time (seconds)')
    ax.yaxis.set_major_formatter(plt.FuncFormatter(format_y))
    ax.set_ylabel('Frequency (Hz)')
def show_mels(mels, titles):
    """Show spectrograms side by side, one subplot per spectrogram.

    Args:
        mels: non-empty sequence of 2-D arrays. Array rows are drawn along the
            y axis (normalised to [0, 1]) and columns along the x axis
            (``0..mel.shape[1]``).
        titles: one subplot title per entry of ``mels``.

    Raises:
        ValueError: if ``mels`` and ``titles`` differ in length.
    """
    if len(mels) != len(titles):
        raise ValueError('Expected one title per spectrogram, got %d spectrograms and %d titles'
                         % (len(mels), len(titles)))
    # squeeze=False keeps `axes` a 2-D array even for a single spectrogram.
    # Otherwise plt.subplots returns a bare Axes, which cannot be zipped over.
    fig, axes = plt.subplots(1, len(mels), dpi=DPI, figsize=(8, 3), squeeze=False)
    for ax, mel, title in zip(axes[0], mels, titles):
        ax.imshow(mel,
                  aspect='auto',
                  cmap='coolwarm',
                  # x in frames, y normalised to [0, 1]: the coordinates that
                  # config_mel_ax's tick formatters expect.
                  extent=(0, mel.shape[1], 0, 1),
                  interpolation='nearest',
                  origin='lower')
        config_mel_ax(ax)
        ax.set_title(title)
    plt.tight_layout()
    plt.show()
_GEN_SUFFIX = '_gen.wav'


def _utterance_stem(gen_path):
    """Return the file name of ``gen_path`` without its ``_gen.wav`` suffix."""
    name = os.path.basename(gen_path)
    if name.endswith(_GEN_SUFFIX):
        name = name[:-len(_GEN_SUFFIX)]
    return name


def _find_reference_mel(gen_path, root='org'):
    """Return the one reference feature file for a generated waveform.

    ``<stem>_gen.wav`` is matched against ``<root>/*/<stem>-feats.npy``.

    Raises:
        ValueError: if no file or more than one file matches.
    """
    ref_mel_filename = _utterance_stem(gen_path) + '-feats.npy'
    # Escape the name so glob metacharacters in it ([, ], *, ?) match literally.
    ref_mel_paths = glob.glob(os.path.join(root, '*', glob.escape(ref_mel_filename)))
    if len(ref_mel_paths) == 0:
        raise ValueError('Reference mel spectrogram not found')
    if len(ref_mel_paths) > 1:
        raise ValueError('Multiple reference mel spectrograms found')
    return ref_mel_paths[0]


def show_generated_examples(paths):
    """Display each generated example next to its reference recording.

    For every ``*_gen.wav`` in the given directories, this shows the generated
    audio, the matching ``*_ref.wav`` from the same directory, and both
    spectrograms. The generated one is read from
    ``features/<stem>_gen-feats.npy``; the reference one comes from
    ``org/*/<stem>-feats.npy``.

    Args:
        paths: directories to scan for ``*_gen.wav`` files.

    Raises:
        ValueError: if an example has zero or several reference feature files.
    """
    gen_paths = []
    for path in paths:
        # glob order is arbitrary; sort so examples are always listed the same way.
        gen_paths.extend(sorted(glob.glob(os.path.join(path, '*' + _GEN_SUFFIX))))
    for i, gen_path in enumerate(gen_paths):
        if i != 0:
            display(HTML('<hr style="border:1px solid black;">'))
        display(HTML('<h3>%.2d</h3>' % (i + 1)))
        stem = _utterance_stem(gen_path)
        display(HTML('<p>Generated</p>'))
        display(Audio(filename=gen_path))
        gen_mel_path = os.path.join('features', stem + '_gen-feats.npy')
        display(HTML('<p>Test set</p>'))
        # Rewrite only the file-name suffix, never other '_gen' substrings
        # that might occur in directory names.
        ref_wav_path = os.path.join(os.path.dirname(gen_path), stem + '_ref.wav')
        display(Audio(filename=ref_wav_path))
        ref_mel_path = _find_reference_mel(gen_path)
        show_mels([np.load(gen_mel_path).T, np.load(ref_mel_path).T],
                  ['Generated', 'Test set'])
import warnings

# Loss curves: training split and held-out dev split.
steps, train_losses = load_loss('train_no_dev.txt')
dev_steps, test_losses = load_loss('dev.txt')
# show_losses plots both curves against the training steps. If the dev log was
# recorded at different steps, the curves would be silently misaligned.
if not np.array_equal(steps, dev_steps):
    warnings.warn('train_no_dev.txt and dev.txt contain different steps; '
                  'the loss curves may be misaligned')
show_losses(steps, train_losses, test_losses)

# Only the dev split is shown. Add 'generated/checkpoint_latest/eval' to the
# list to also inspect the eval split.
show_generated_examples(['generated/checkpoint_latest/dev'])
Generated
Test set
Generated
Test set
Generated
Test set
Generated
Test set
Generated
Test set
Generated
Test set
Generated
Test set
Generated
Test set
Generated
Test set
Generated
Test set
def compare_mel_vs_chroma(is_log, labeled_freqs=None, freq_range=None, midi_range=(47, 120)):
    """Plot the mel frequencies next to the frequencies of chromatic pitches.

    Two rows of dots are drawn:
    - y=0: the ``hparams.num_mels`` frequencies returned by
      ``librosa.mel_frequencies`` between ``fmin`` and ``fmax``.
    - y=-1: the frequencies of MIDI notes ``midi_range[0]`` up to
      ``midi_range[1] - 1``.

    Args:
        is_log: if True, plot natural-log frequency on the x axis instead of Hz.
        labeled_freqs: optional frequencies (Hz) to mark with labelled vertical
            lines. Defaults to none.
        freq_range: optional ``(min, max)`` x-axis limits, in the plotted units.
        midi_range: ``(start, stop)`` MIDI note numbers, stop exclusive, for the
            chromatic row. Defaults to the previously hard-coded ``(47, 120)``.
    """
    # None instead of a mutable [] default, which would be shared across calls.
    if labeled_freqs is None:
        labeled_freqs = ()
    fig, ax = plt.subplots(dpi=DPI, figsize=(8, 3))
    mel_freqs = librosa.mel_frequencies(n_mels=hparams.num_mels, fmin=hparams.fmin, fmax=hparams.fmax)
    chroma_freqs = librosa.midi_to_hz(np.arange(*midi_range))
    for labeled_freq in labeled_freqs:
        label = '%.2f Hz' % labeled_freq
        ax.axvline(x=np.log(labeled_freq) if is_log else labeled_freq, color='grey',
                   label=label)
    for i, (freqs, name) in enumerate(zip([mel_freqs, chroma_freqs], ['Mel', 'Chromatic'])):
        label = '%s (n = %d)' % (name, len(freqs))
        # Each frequency set gets its own row: y=0 for mel, y=-1 for chromatic.
        ax.plot(np.log(freqs) if is_log else freqs, np.zeros_like(freqs) - i, '.', label=label)
    ax.set_xlabel('Log frequency (log Hz)' if is_log else 'Frequency (Hz)')
    if freq_range is not None:
        ax.set_xlim(freq_range)
    ax.set_ylim(-2, 1)
    # The y position only separates the rows; it carries no value.
    ax.yaxis.set_visible(False)
    ax.legend()
    plt.tight_layout()
    plt.show()
# Whole frequency range on a natural-log axis...
compare_mel_vs_chroma(True)
# ...then a linear axis zoomed in on 100-500 Hz.
compare_mel_vs_chroma(False, freq_range=(100, 500))